# -*- coding: utf-8 -*-
"""
nasbench301.example
-------------------
Demo: with the patches applied, load the v1.0 performance and runtime models and
the ConfigSpace, then predict both quantities for a DARTS Genotype and for a
randomly sampled ConfigSpace configuration.
"""

from __future__ import annotations
from collections import namedtuple

# Importing `patches` applies the patches automatically; the import itself does not trigger a download.
from patches import apply_all_patches, load_models, load_configspace_with_patches

# Call apply_all_patches() explicitly with verbose=True so that running this demo
# reports the applied patches. Patches may already be active from the import, so
# this relies on apply_all_patches being idempotent — NOTE(review): confirm in patches.py.
apply_all_patches(verbose=True)


def smoke_test():
    # 1) Load models (if not downloaded, will auto-download to module directory)
    perf, rt = load_models(version="1.0")
    # 2) Load ConfigSpace (clean JSON + new API + fix hyperparameter validation)
    cs = load_configspace_with_patches()

    # Example Genotype for DARTS
    Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
    genotype_config = Genotype(
        normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
                ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)],
        normal_concat=[2, 3, 4, 5],
        reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1),
                ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)],
        reduce_concat=[2, 3, 4, 5]
    )

    # Sample a configuration from ConfigSpace
    cfg = cs.sample_configuration()

    print("==> Predict runtime and performance...")
    pred_geno = perf.predict(config=genotype_config, representation="genotype",     with_noise=True)
    pred_cfg  = perf.predict(config=cfg,              representation="configspace", with_noise=True)

    rt_geno = rt.predict(config=genotype_config, representation="genotype")
    rt_cfg  = rt.predict(config=cfg,              representation="configspace")

    print(f"Genotype -> perf={float(pred_geno):.6f}, runtime={float(rt_geno):.6f}")
    print(f"Config   -> perf={float(pred_cfg):.6f},  runtime={float(rt_cfg):.6f}")


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    smoke_test()
